{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Zero-shot object detection"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "従来、[オブジェクト検出](https://huggingface.co/docs/transformers/main/ja/tasks/object_detection) に使用されるモデルには、トレーニング用のラベル付き画像データセットが必要でした。\n",
    "トレーニング データからのクラスのセットの検出に限定されます。\n",
    "\n",
    "ゼロショットオブジェクト検出は、別のアプローチを使用する [OWL-ViT](https://huggingface.co/docs/transformers/main/ja/tasks/../model_doc/owlvit) モデルによってサポートされています。 OWL-ViT\n",
    "オープン語彙オブジェクト検出器です。これは、フリーテキストクエリに基づいて画像内のオブジェクトを検出できることを意味します。\n",
    "ラベル付きデータセットでモデルを微調整する必要性。\n",
    "\n",
    "OWL-ViTは、マルチモーダル表現を利用してオープン語彙の検出を実行します。 [CLIP](https://huggingface.co/docs/transformers/main/ja/tasks/../model_doc/clip) とを組み合わせます。\n",
    "軽量のオブジェクト分類および位置特定ヘッド。オープン語彙の検出は、CLIP のテキスト エンコーダーを使用してフリーテキスト クエリを埋め込み、それらをオブジェクト分類およびローカリゼーション ヘッドへの入力として使用することによって実現されます。\n",
    "画像とそれに対応するテキストの説明を関連付け、ViT は画像パッチを入力として処理します。作家たち\n",
    "のOWL-ViTは、まずCLIPをゼロからトレーニングし、次に標準の物体検出データセットを使用してOWL-ViTをエンドツーエンドで微調整しました。\n",
    "二部マッチング損失。\n",
    "\n",
    "このアプローチを使用すると、モデルはラベル付きデータセットで事前にトレーニングしなくても、テキストの説明に基づいてオブジェクトを検出できます。\n",
    "\n",
    "このガイドでは、OWL-ViT の使用方法を学習します。\n",
    "- テキストプロンプトに基づいてオブジェクトを検出します\n",
    "- バッチオブジェクト検出用\n",
    "- 画像誘導物体検出用\n",
    "\n",
    "始める前に、必要なライブラリがすべてインストールされていることを確認してください。\n",
    "\n",
    "```bash\n",
    "pip install -q transformers\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Zero-shot object detection pipeline"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "OWL-ViTによる推論を試す最も簡単な方法は、OWL-ViTを[pipeline()](https://huggingface.co/docs/transformers/main/ja/main_classes/pipelines#transformers.pipeline)で使用することです。パイプラインをインスタンス化する\n",
    "[Hugging Face Hub のチェックポイント](https://huggingface.co/models?other=owlvit) からのゼロショット オブジェクト検出の場合:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import pipeline\n",
    "\n",
    "checkpoint = \"google/owlvit-base-patch32\"\n",
    "detector = pipeline(model=checkpoint, task=\"zero-shot-object-detection\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "次に、物体を検出したい画像を選択します。ここでは、宇宙飛行士アイリーン・コリンズの画像を使用します。\n",
    "[NASA](https://www.nasa.gov/multimedia/imagegallery/index.html) Great Images データセットの一部。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import skimage\n",
    "import numpy as np\n",
    "from PIL import Image\n",
    "\n",
    "image = skimage.data.astronaut()\n",
    "image = Image.fromarray(np.uint8(image)).convert(\"RGB\")\n",
    "\n",
    "image"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<div class=\"flex justify-center\">\n",
    "     <img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/zero-sh-obj-detection_1.png\" alt=\"Astronaut Eileen Collins\"/>\n",
    "</div>\n",
    "\n",
    "検索する画像と候補オブジェクトのラベルをパイプラインに渡します。\n",
    "ここでは画像を直接渡します。他の適切なオプションには、画像へのローカル パスまたは画像 URL が含まれます。また、画像をクエリするすべてのアイテムのテキスト説明も渡します。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[{'score': 0.3571370542049408,\n",
       "  'label': 'human face',\n",
       "  'box': {'xmin': 180, 'ymin': 71, 'xmax': 271, 'ymax': 178}},\n",
       " {'score': 0.28099656105041504,\n",
       "  'label': 'nasa badge',\n",
       "  'box': {'xmin': 129, 'ymin': 348, 'xmax': 206, 'ymax': 427}},\n",
       " {'score': 0.2110239565372467,\n",
       "  'label': 'rocket',\n",
       "  'box': {'xmin': 350, 'ymin': -1, 'xmax': 468, 'ymax': 288}},\n",
       " {'score': 0.13790413737297058,\n",
       "  'label': 'star-spangled banner',\n",
       "  'box': {'xmin': 1, 'ymin': 1, 'xmax': 105, 'ymax': 509}},\n",
       " {'score': 0.11950037628412247,\n",
       "  'label': 'nasa badge',\n",
       "  'box': {'xmin': 277, 'ymin': 338, 'xmax': 327, 'ymax': 380}},\n",
       " {'score': 0.10649408400058746,\n",
       "  'label': 'rocket',\n",
       "  'box': {'xmin': 358, 'ymin': 64, 'xmax': 424, 'ymax': 280}}]"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "predictions = detector(\n",
    "    image,\n",
    "    candidate_labels=[\"human face\", \"rocket\", \"nasa badge\", \"star-spangled banner\"],\n",
    ")\n",
    "predictions"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "予測を視覚化してみましょう。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from PIL import ImageDraw\n",
    "\n",
    "draw = ImageDraw.Draw(image)\n",
    "\n",
    "for prediction in predictions:\n",
    "    box = prediction[\"box\"]\n",
    "    label = prediction[\"label\"]\n",
    "    score = prediction[\"score\"]\n",
    "\n",
    "    xmin, ymin, xmax, ymax = box.values()\n",
    "    draw.rectangle((xmin, ymin, xmax, ymax), outline=\"red\", width=1)\n",
    "    draw.text((xmin, ymin), f\"{label}: {round(score,2)}\", fill=\"white\")\n",
    "\n",
    "image"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<div class=\"flex justify-center\">\n",
    "     <img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/zero-sh-obj-detection_2.png\" alt=\"Visualized predictions on NASA image\"/>\n",
    "</div>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Text-prompted zero-shot object detection by hand"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "ゼロショット物体検出パイプラインの使用方法を確認したので、同じことを再現してみましょう。\n",
    "手動で結果を取得します。\n",
    "\n",
    "まず、[Hugging Face Hub のチェックポイント](https://huggingface.co/models?other=owlvit) からモデルと関連プロセッサをロードします。\n",
    "ここでは、前と同じチェックポイントを使用します。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection\n",
    "\n",
    "model = AutoModelForZeroShotObjectDetection.from_pretrained(checkpoint)\n",
    "processor = AutoProcessor.from_pretrained(checkpoint)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "気分を変えて、別の画像を撮ってみましょう。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import requests\n",
    "\n",
    "url = \"https://unsplash.com/photos/oj0zeY2Ltk4/download?ixid=MnwxMjA3fDB8MXxzZWFyY2h8MTR8fHBpY25pY3xlbnwwfHx8fDE2Nzc0OTE1NDk&force=true&w=640\"\n",
    "im = Image.open(requests.get(url, stream=True).raw)\n",
    "im"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<div class=\"flex justify-center\">\n",
    "     <img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/zero-sh-obj-detection_3.png\" alt=\"Beach photo\"/>\n",
    "</div>\n",
    "\n",
    "プロセッサを使用してモデルの入力を準備します。プロセッサーは、\n",
    "サイズ変更と正規化によるモデルの画像と、テキスト入力を処理する [CLIPTokenizer](https://huggingface.co/docs/transformers/main/ja/model_doc/clip#transformers.CLIPTokenizer) です。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "text_queries = [\"hat\", \"book\", \"sunglasses\", \"camera\"]\n",
    "inputs = processor(text=text_queries, images=im, return_tensors=\"pt\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "入力をモデルに渡し、後処理し、結果を視覚化します。以前は画像プロセッサによって画像のサイズが変更されていたため、\n",
    "それらをモデルにフィードするには、`post_process_object_detection()` メソッドを使用して、予測された境界を確認する必要があります。\n",
    "ボックスは元の画像を基準とした正しい座標を持ちます。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "with torch.no_grad():\n",
    "    outputs = model(**inputs)\n",
    "    target_sizes = torch.tensor([im.size[::-1]])\n",
    "    results = processor.post_process_object_detection(outputs, threshold=0.1, target_sizes=target_sizes)[0]\n",
    "\n",
    "draw = ImageDraw.Draw(im)\n",
    "\n",
    "scores = results[\"scores\"].tolist()\n",
    "labels = results[\"labels\"].tolist()\n",
    "boxes = results[\"boxes\"].tolist()\n",
    "\n",
    "for box, score, label in zip(boxes, scores, labels):\n",
    "    xmin, ymin, xmax, ymax = box\n",
    "    draw.rectangle((xmin, ymin, xmax, ymax), outline=\"red\", width=1)\n",
    "    draw.text((xmin, ymin), f\"{text_queries[label]}: {round(score,2)}\", fill=\"white\")\n",
    "\n",
    "im"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<div class=\"flex justify-center\">\n",
    "     <img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/zero-sh-obj-detection_4.png\" alt=\"Beach photo with detected objects\"/>\n",
    "</div>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Batch processing"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "複数の画像セットとテキスト クエリを渡して、複数の画像内の異なる (または同じ) オブジェクトを検索できます。\n",
    "宇宙飛行士の画像とビーチの画像を組み合わせてみましょう。\n",
    "バッチ処理の場合、テキスト クエリをネストされたリストとしてプロセッサに渡し、画像を PIL イメージのリストとして渡す必要があります。\n",
    "PyTorch テンソル、または NumPy 配列。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "images = [image, im]\n",
    "text_queries = [\n",
    "    [\"human face\", \"rocket\", \"nasa badge\", \"star-spangled banner\"],\n",
    "    [\"hat\", \"book\", \"sunglasses\", \"camera\"],\n",
    "]\n",
    "inputs = processor(text=text_queries, images=images, return_tensors=\"pt\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "以前は後処理のために単一の画像のサイズをテンソルとして渡していましたが、タプルを渡すこともできます。\n",
    "複数の画像のタプルのリスト。 2 つの例の予測を作成し、2 番目の例 (`image_idx = 1`) を視覚化しましょう。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with torch.no_grad():\n",
    "    outputs = model(**inputs)\n",
    "    target_sizes = [x.size[::-1] for x in images]\n",
    "    results = processor.post_process_object_detection(outputs, threshold=0.1, target_sizes=target_sizes)\n",
    "\n",
    "image_idx = 1\n",
    "draw = ImageDraw.Draw(images[image_idx])\n",
    "\n",
    "scores = results[image_idx][\"scores\"].tolist()\n",
    "labels = results[image_idx][\"labels\"].tolist()\n",
    "boxes = results[image_idx][\"boxes\"].tolist()\n",
    "\n",
    "for box, score, label in zip(boxes, scores, labels):\n",
    "    xmin, ymin, xmax, ymax = box\n",
    "    draw.rectangle((xmin, ymin, xmax, ymax), outline=\"red\", width=1)\n",
    "    draw.text((xmin, ymin), f\"{text_queries[image_idx][label]}: {round(score,2)}\", fill=\"white\")\n",
    "\n",
    "images[image_idx]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<div class=\"flex justify-center\">\n",
    "     <img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/zero-sh-obj-detection_4.png\" alt=\"Beach photo with detected objects\"/>\n",
    "</div>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Image-guided object detection"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "テキストクエリによるゼロショットオブジェクト検出に加えて、OWL-ViTは画像ガイドによるオブジェクト検出を提供します。これはつまり\n",
    "画像クエリを使用して、ターゲット画像内の類似したオブジェクトを検索できます。\n",
    "テキスト クエリとは異なり、使用できるサンプル画像は 1 つだけです。\n",
    "\n",
    "対象画像としてソファに2匹の猫がいる画像と、1匹の猫の画像を撮影しましょう\n",
    "クエリとして:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "url = \"http://images.cocodataset.org/val2017/000000039769.jpg\"\n",
    "image_target = Image.open(requests.get(url, stream=True).raw)\n",
    "\n",
    "query_url = \"http://images.cocodataset.org/val2017/000000524280.jpg\"\n",
    "query_image = Image.open(requests.get(query_url, stream=True).raw)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "画像を簡単に見てみましょう。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "\n",
    "fig, ax = plt.subplots(1, 2)\n",
    "ax[0].imshow(image_target)\n",
    "ax[1].imshow(query_image)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<div class=\"flex justify-center\">\n",
    "     <img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/zero-sh-obj-detection_5.png\" alt=\"Cats\"/>\n",
    "</div>\n",
    "\n",
    "前処理ステップでは、テキスト クエリの代わりに `query_images` を使用する必要があります。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "inputs = processor(images=image_target, query_images=query_image, return_tensors=\"pt\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "予測の場合、入力をモデルに渡す代わりに、`image_guided_detection()` に渡します。予測を描く\n",
    "ラベルがないことを除いては以前と同様です。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with torch.no_grad():\n",
    "    outputs = model.image_guided_detection(**inputs)\n",
    "    target_sizes = torch.tensor([image_target.size[::-1]])\n",
    "    results = processor.post_process_image_guided_detection(outputs=outputs, target_sizes=target_sizes)[0]\n",
    "\n",
    "draw = ImageDraw.Draw(image_target)\n",
    "\n",
    "scores = results[\"scores\"].tolist()\n",
    "boxes = results[\"boxes\"].tolist()\n",
    "\n",
    "for box, score, label in zip(boxes, scores, labels):\n",
    "    xmin, ymin, xmax, ymax = box\n",
    "    draw.rectangle((xmin, ymin, xmax, ymax), outline=\"white\", width=4)\n",
    "\n",
    "image_target"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<div class=\"flex justify-center\">\n",
    "     <img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/zero-sh-obj-detection_6.png\" alt=\"Cats with bounding boxes\"/>\n",
    "</div>\n",
    "\n",
    "OWL-ViTによる推論をインタラクティブに試したい場合は、このデモをチェックしてください。\n",
    "\n",
    "<iframe\n",
    "\tsrc=\"https://adirik-owl-vit.hf.space\"\n",
    "\tframeborder=\"0\"\n",
    "\twidth=\"850\"\n",
    "\theight=\"450\"\n",
    "></iframe>"
   ]
  }
 ],
 "metadata": {},
 "nbformat": 4,
 "nbformat_minor": 4
}
